{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert import tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file='MASS.wordpiece', do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['saya', 'suka', 'makan', 'an', '##jen', '##g']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize('saya suka makan anjeng')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "rng = random.Random(12345)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_length = 300\n",
    "dupe_factor = 5\n",
    "short_seq_prob = 0.1\n",
    "masked_lm_prob = 0.2\n",
    "max_predictions_per_seq = 20\n",
    "eos_id = 1\n",
    "\n",
    "vocab_words = list(tokenizer.vocab.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../pure-text/dumping-instagram.txt',\n",
       " '../pure-text/parliament-text.txt',\n",
       " '../pure-text/dumping-parliament.txt',\n",
       " '../pure-text/dumping-twitter.txt',\n",
       " '../pure-text/normalization-100000.txt',\n",
       " '../pure-text/iium.wordpiece-vocab.txt',\n",
       " '../pure-text/dumping-iium.txt',\n",
       " '../pure-text/twitter-normalization-100000.txt',\n",
       " '../pure-text/dumping-wiki.txt',\n",
       " '../pure-text/dumping-news.txt',\n",
       " '../pure-text/dumping-watpadd.txt',\n",
       " '../pure-text/dumping-pdf.txt',\n",
       " '../pure-text/dumping-common-crawl.txt']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "files = glob('../pure-text/*.txt')\n",
    "files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "def create_single_instances(\n",
    "    input_files,\n",
    "    tokenizer,\n",
    "):\n",
    "    all_documents = []\n",
    "    for input_file in input_files:\n",
    "        with open(input_file,'r') as f:\n",
    "            for line in tqdm(f):\n",
    "                line = tokenization.convert_to_unicode(line)\n",
    "                if not line:\n",
    "                    break\n",
    "                line = line.strip()\n",
    "                if line:\n",
    "                    tokens = tokenizer.tokenize(line)\n",
    "                    if len(tokens) > 10:\n",
    "                        all_documents.append(tokens)\n",
    "\n",
    "    all_documents = [x for x in all_documents if x]\n",
    "    return all_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_MaskedLanguagePairDataset(row):\n",
    "    max_num_tokens = max_seq_length\n",
    "    target_seq_length = max_num_tokens\n",
    "    if rng.random() < short_seq_prob:\n",
    "        target_seq_length = rng.randint(2, max_num_tokens)\n",
    "    tokens = row[:max_num_tokens]\n",
    "    \n",
    "    output_tokens = list(tokens)\n",
    "    \n",
    "    cand_indexes = []\n",
    "    for (i, token) in enumerate(tokens):\n",
    "        if token == '[CLS]' or token == '[SEP]':\n",
    "            continue\n",
    "        if (\n",
    "            len(cand_indexes) >= 1\n",
    "            and token.startswith('##')\n",
    "        ):\n",
    "            cand_indexes[-1].append(i)\n",
    "        else:\n",
    "            cand_indexes.append([i])\n",
    "\n",
    "    num_to_predict = min(\n",
    "        max_predictions_per_seq,\n",
    "        max(1, int(round(len(tokens) * masked_lm_prob))),\n",
    "    )\n",
    "    start = random.randint(1, len(cand_indexes) - num_to_predict)\n",
    "\n",
    "    label, input_decoder = [], []\n",
    "    covered_indexes = set()\n",
    "    for index_set in cand_indexes[start:]:\n",
    "        if len(label) >= num_to_predict:\n",
    "            break\n",
    "        is_any_index_covered = False\n",
    "        for index in index_set:\n",
    "            if index in covered_indexes:\n",
    "                is_any_index_covered = True\n",
    "                break\n",
    "        if is_any_index_covered:\n",
    "            continue\n",
    "        for index in index_set:\n",
    "            covered_indexes.add(index)\n",
    "\n",
    "            masked_token = None\n",
    "            # 80% of the time, replace with [MASK]\n",
    "            if rng.random() < 0.8:\n",
    "                masked_token = '[MASK]'\n",
    "            else:\n",
    "                # 10% of the time, keep original\n",
    "                if rng.random() < 0.5:\n",
    "                    masked_token = tokens[index]\n",
    "                # 10% of the time, replace with random word\n",
    "                else:\n",
    "                    masked_token = vocab_words[\n",
    "                        rng.randint(0, len(vocab_words) - 1)\n",
    "                    ]\n",
    "\n",
    "            output_tokens[index] = masked_token\n",
    "            label.append(tokens[index])\n",
    "            input_decoder.append(tokens[index - 1])\n",
    "            \n",
    "    return output_tokens, input_decoder, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def create_pair_instances(\n",
    "    input_files,\n",
    "    tokenizer,\n",
    "):\n",
    "    all_documents = []\n",
    "    for input_file in input_files:\n",
    "        with open(input_file) as fopen:\n",
    "            data = json.load(fopen)\n",
    "        for i in tqdm(range(len(data['left']))):\n",
    "            line_l = tokenization.convert_to_unicode(data['left'][i])\n",
    "            line_r = tokenization.convert_to_unicode(data['right'][i])\n",
    "            \n",
    "            line_l = line_l.strip()\n",
    "            line_r = line_r.strip()\n",
    "            if line_l and line_r:\n",
    "                tokens_l = tokenizer.tokenize(line_l)\n",
    "                tokens_r = tokenizer.tokenize(line_r)\n",
    "                if len(tokens_l) < max_seq_length and len(tokens_r) < max_seq_length:\n",
    "                    all_documents.append((tokens_l, tokens_r))\n",
    "\n",
    "    return all_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# r_pair = create_pair_instances(['en-ms.json'], tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_NoisyLanguagePairDataset(row):\n",
    "    max_num_tokens = max_seq_length\n",
    "    target_seq_length = max_num_tokens\n",
    "    if rng.random() < short_seq_prob:\n",
    "        target_seq_length = rng.randint(2, max_num_tokens)\n",
    "    tokens_l = row[0][:max_num_tokens]\n",
    "    tokens_r = row[1][:max_num_tokens]\n",
    "    \n",
    "    cand_indexes_l = []\n",
    "    for (i, token) in enumerate(tokens_l):\n",
    "        if token == '[CLS]' or token == '[SEP]':\n",
    "            continue\n",
    "        if (\n",
    "            len(cand_indexes_l) >= 1\n",
    "            and token.startswith('##')\n",
    "        ):\n",
    "            cand_indexes_l[-1].append(i)\n",
    "        else:\n",
    "            cand_indexes_l.append([i])\n",
    "            \n",
    "    cand_indexes_r = []\n",
    "    for (i, token) in enumerate(tokens_r):\n",
    "        if token == '[CLS]' or token == '[SEP]':\n",
    "            continue\n",
    "        if (\n",
    "            len(cand_indexes_r) >= 1\n",
    "            and token.startswith('##')\n",
    "        ):\n",
    "            cand_indexes_r[-1].append(i)\n",
    "        else:\n",
    "            cand_indexes_r.append([i])\n",
    "            \n",
    "    rng.shuffle(cand_indexes_l)\n",
    "    rng.shuffle(cand_indexes_r)\n",
    "\n",
    "    num_to_predict = min(\n",
    "        max_predictions_per_seq,\n",
    "        max(1, int(round(len(tokens_l) * masked_lm_prob))),\n",
    "    )\n",
    "    \n",
    "    masked_lms = []\n",
    "    output_tokens_l = list(tokens_l)\n",
    "    covered_indexes = set()\n",
    "    for index_set in cand_indexes_l:\n",
    "        if len(masked_lms) >= num_to_predict:\n",
    "            break\n",
    "        is_any_index_covered = False\n",
    "        for index in index_set:\n",
    "            if index in covered_indexes:\n",
    "                is_any_index_covered = True\n",
    "                break\n",
    "        if is_any_index_covered:\n",
    "            continue\n",
    "        for index in index_set:\n",
    "            covered_indexes.add(index)\n",
    "\n",
    "            masked_token = None\n",
    "            # 80% of the time, replace with [MASK]\n",
    "            if rng.random() < 0.8:\n",
    "                masked_token = '[MASK]'\n",
    "            else:\n",
    "                # 10% of the time, keep original\n",
    "                if rng.random() < 0.5:\n",
    "                    masked_token = tokens_l[index]\n",
    "                # 10% of the time, replace with random word\n",
    "                else:\n",
    "                    masked_token = vocab_words[\n",
    "                        rng.randint(0, len(vocab_words) - 1)\n",
    "                    ]\n",
    "\n",
    "            output_tokens_l[index] = masked_token\n",
    "            masked_lms.append(tokens_l[index])\n",
    "    \n",
    "    num_to_predict = min(\n",
    "        max_predictions_per_seq,\n",
    "        max(1, int(round(len(tokens_r) * masked_lm_prob))),\n",
    "    )\n",
    "    \n",
    "    output_tokens_r = list(tokens_r)\n",
    "    label, input_decoder = [], []\n",
    "    covered_indexes = set()\n",
    "    for index_set in cand_indexes_r:\n",
    "        if len(label) >= num_to_predict:\n",
    "            break\n",
    "        is_any_index_covered = False\n",
    "        for index in index_set:\n",
    "            if index in covered_indexes:\n",
    "                is_any_index_covered = True\n",
    "                break\n",
    "        if is_any_index_covered:\n",
    "            continue\n",
    "        for index in index_set:\n",
    "            covered_indexes.add(index)\n",
    "\n",
    "            masked_token = None\n",
    "            # 80% of the time, replace with [MASK]\n",
    "            if rng.random() < 0.8:\n",
    "                masked_token = '[MASK]'\n",
    "            else:\n",
    "                # 10% of the time, keep original\n",
    "                if rng.random() < 0.5:\n",
    "                    masked_token = tokens_r[index]\n",
    "                # 10% of the time, replace with random word\n",
    "                else:\n",
    "                    masked_token = vocab_words[\n",
    "                        rng.randint(0, len(vocab_words) - 1)\n",
    "                    ]\n",
    "\n",
    "            output_tokens_r[index] = masked_token\n",
    "            label.append(tokens_r[index])\n",
    "            \n",
    "    return output_tokens_l, output_tokens_r, tokens_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import tensorflow as tf\n",
    "\n",
    "maxlen = max_seq_length\n",
    "\n",
    "def create_int_feature(values):\n",
    "    feature = tf.train.Feature(\n",
    "        int64_list = tf.train.Int64List(value = list(values))\n",
    "    )\n",
    "    return feature\n",
    "\n",
    "def to_tfrecord(rows, filename):\n",
    "    input_encoders, input_decoders, labels = [], [], []\n",
    "    \n",
    "    for i in tqdm(range(len(rows))):\n",
    "        input_encoder, input_decoder, label = rows[0]\n",
    "        input_encoder = tokenizer.convert_tokens_to_ids(input_encoder)\n",
    "        input_decoder = tokenizer.convert_tokens_to_ids(input_decoder) + [eos_id]\n",
    "        label = tokenizer.convert_tokens_to_ids(label) + [eos_id]\n",
    "        input_encoder = input_encoder + [0] * (maxlen - len(input_encoder))\n",
    "        input_decoder = input_decoder + [0] * (maxlen - len(input_decoder))\n",
    "        label = label + [0] * (maxlen - len(label))\n",
    "        input_encoders.append(input_encoder)\n",
    "        input_decoders.append(input_decoder)\n",
    "        labels.append(label)\n",
    "    \n",
    "    r = tf.python_io.TFRecordWriter(f'{filename}.tfrecord')\n",
    "    for i in tqdm(range(len(labels))):\n",
    "        features = collections.OrderedDict()\n",
    "        features['input_encoder'] = create_int_feature(input_encoders[i])\n",
    "        features['input_decoder'] = create_int_feature(input_decoders[i])\n",
    "        features['y'] = create_int_feature(labels[i])\n",
    "        tf_example = tf.train.Example(\n",
    "            features = tf.train.Features(feature = features)\n",
    "        )\n",
    "        r.write(tf_example.SerializeToString())\n",
    "    r.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def single_pair(files, dupe_factor = 3):\n",
    "    for file in files:\n",
    "        print(file)\n",
    "        f = os.path.split(file)[1]\n",
    "        r = create_single_instances([file], tokenizer = tokenizer)\n",
    "        results = []\n",
    "        for row in r:\n",
    "            for _ in range(dupe_factor):\n",
    "                try:\n",
    "                    results.append(process_MaskedLanguagePairDataset(row))\n",
    "                except:\n",
    "                    pass\n",
    "        to_tfrecord(results, f)\n",
    "        \n",
    "def double_pair(files, dupe_factor = 20):\n",
    "    for file in files:\n",
    "        print(file)\n",
    "        f = os.path.split(file)[1]\n",
    "        r = create_pair_instances([file], tokenizer = tokenizer)\n",
    "        results = []\n",
    "        for row in r:\n",
    "            for _ in range(dupe_factor):\n",
    "                try:\n",
    "                    results.append(process_NoisyLanguagePairDataset(row))\n",
    "                except:\n",
    "                    pass\n",
    "        to_tfrecord(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected = ['../pure-text/dumping-parliament.txt',\n",
    "           '../pure-text/dumping-iium.txt',\n",
    "           '../pure-text/dumping-wiki.txt',\n",
    "           '../pure-text/dumping-news.txt',\n",
    "           '../pure-text/dumping-watpadd.txt',\n",
    "           '../pure-text/dumping-pdf.txt']\n",
    "\n",
    "pairs = ['en-ms.json', 'ms-en.json']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "573it [00:00, 5720.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-parliament.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "960868it [02:04, 7706.33it/s]\n",
      "100%|██████████| 1219653/1219653 [00:36<00:00, 33380.67it/s]\n",
      "100%|██████████| 1219653/1219653 [04:41<00:00, 4330.52it/s]\n",
      "444it [00:00, 4432.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-iium.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1139838it [03:22, 5637.71it/s]\n",
      "100%|██████████| 2102775/2102775 [01:02<00:00, 33632.71it/s]\n",
      "100%|██████████| 2102775/2102775 [08:06<00:00, 4326.24it/s]\n",
      "279it [00:00, 2785.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-wiki.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2050801it [09:44, 3509.09it/s]\n",
      "100%|██████████| 4765857/4765857 [02:38<00:00, 30111.50it/s]\n",
      "100%|██████████| 4765857/4765857 [18:30<00:00, 4290.66it/s]\n",
      "263it [00:00, 2618.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-news.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1897306it [11:02, 2863.58it/s]\n",
      "100%|██████████| 4799394/4799394 [02:07<00:00, 37591.00it/s]\n",
      "100%|██████████| 4799394/4799394 [18:31<00:00, 4319.05it/s]\n",
      "557it [00:00, 5569.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-watpadd.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1500469it [04:04, 6138.31it/s]\n",
      "100%|██████████| 3223416/3223416 [01:34<00:00, 34251.87it/s]\n",
      "100%|██████████| 3223416/3223416 [12:20<00:00, 4351.92it/s]\n",
      "611it [00:00, 6104.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-pdf.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "651956it [01:51, 5866.14it/s]\n",
      "100%|██████████| 1325601/1325601 [00:40<00:00, 32677.50it/s]\n",
      "100%|██████████| 1325601/1325601 [05:04<00:00, 4359.22it/s]\n"
     ]
    }
   ],
   "source": [
    "single_pair(selected)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
